{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# default_exp constructor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# model_constructor\n",
    "\n",
 Create pytorch model."">
    "> Create a PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *\n",
    "from fastcore.test import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "import torch.nn as nn\n",
    "import torch\n",
    "from collections import OrderedDict\n",
    "from model_constructor.layers import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class Stem(nn.Sequential):\n",
    "    \"\"\"Base stem\"\"\"\n",
    "    def __init__(self, c_in=3,  stem_sizes=[], stem_out=64, \n",
    "            conv_layer=ConvLayer, stride_on=0,    \n",
    "            stem_bn_last=False, bn_1st=True,   \n",
    "            stem_use_pool=True, stem_pool=nn.MaxPool2d(kernel_size=3, stride=2, padding=1), **kwargs): \n",
    "        self.sizes = [c_in] + stem_sizes + [stem_out]\n",
    "        num_layers = len(self.sizes)-1\n",
    "        stem = [(f\"conv_{i}\", conv_layer(self.sizes[i], self.sizes[i+1], \n",
    "                stride=2 if i==stride_on else 1, act=True,\n",
    "                bn_layer=not stem_bn_last if i==num_layers-1 else True, \n",
    "                bn_1st=bn_1st, **kwargs))\n",
    "                    for i in range(num_layers)]\n",
    "        if stem_use_pool: stem += [('pool', stem_pool)] \n",
    "        if stem_bn_last: stem.append(('bn', nn.BatchNorm2d(stem_out)))\n",
    "        super().__init__(OrderedDict(stem))\n",
    "    def extra_repr(self):\n",
    "            return f\"sizes: {self.sizes}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 64]\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stem = Stem()\n",
    "stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 64, 32, 32])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xb = torch.randn(8, 3, 128, 128)\n",
    "y = stem(xb)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 64]\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stem = Stem(use_bn=True)\n",
    "stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 64, 32, 32])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xb = torch.randn(8, 3, 128, 128)\n",
    "y = stem(xb)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 32, 64]\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (conv_1): ConvLayer(\n",
       "    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stem = Stem(use_bn=True, stem_sizes=[32], stride_on=1)\n",
    "stem"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 64, 32, 32])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "xb = torch.randn(8, 3, 128, 128)\n",
    "y = stem(xb)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert y.shape, torch.Size([64, 64, 32, 32])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Stem(\n",
       "  sizes: [3, 32, 64]\n",
       "  (conv_0): ConvLayer(\n",
       "    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "    (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (conv_1): ConvLayer(\n",
       "    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (act_fn): ReLU(inplace=True)\n",
       "    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  )\n",
       "  (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stem = Stem(use_bn=True, stem_sizes=[32], stride_on=1, bn_1st=False)\n",
    "stem"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Body"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Body constructor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "class BasicLayer(nn.Sequential):\n",
    "    '''Layer from blocks'''\n",
    "    def __init__(self, block, blocks, ni, nf, expansion, stride, sa=False,**kwargs):\n",
    "        self.ni = ni\n",
    "        self.nf = nf\n",
    "        self.blocks = blocks\n",
    "        self.expansion = expansion\n",
    "        super().__init__(OrderedDict(\n",
    "            [(f'block_{i}', block(ni if i==0 else nf, nf, expansion, \n",
    "                            stride if i==0 else 1,\n",
    "                            sa=sa if i==blocks-1 else False, \n",
    "                                  **kwargs))\n",
    "              for i in range(blocks)]))\n",
    "    def extra_repr(self):\n",
    "        return f'from {self.ni*self.expansion} to {self.nf}, {self.blocks} blocks, expansion {self.expansion}.'                    \n",
    "\n",
    "class Body(nn.Sequential): \n",
    "    '''Constructor for body'''\n",
    "    def __init__(self, block, \n",
    "                 body_in=64, body_out=512,  \n",
    "                 bodylayer=BasicLayer, expansion=1,\n",
    "                 layer_szs=[64,128,256,], blocks=[2,2,2,2],\n",
    "                 sa=False, **kwargs):  \n",
    "        layer_szs = [body_in//expansion] + layer_szs + [body_out]\n",
    "        num_layers = len(layer_szs)-1\n",
    "        layers = [(f\"layer_{i}\", bodylayer(block, blocks[i], \n",
    "                            layer_szs[i], layer_szs[i+1], expansion,\n",
    "                            1 if i==0 else 2,  \n",
    "                            sa=sa if i==0 else False,  \n",
    "                                           **kwargs))\n",
    "                    for i in range(num_layers)]\n",
    "        super().__init__(OrderedDict(layers))\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Body(\n",
       "  (layer_0): BasicLayer(\n",
       "    from 64 to 64, 2 blocks, expansion 1.\n",
       "    (block_0): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_1): BasicLayer(\n",
       "    from 64 to 128, 2 blocks, expansion 1.\n",
       "    (block_0): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (downsample): ConvLayer(\n",
       "        (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_2): BasicLayer(\n",
       "    from 128 to 256, 2 blocks, expansion 1.\n",
       "    (block_0): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (downsample): ConvLayer(\n",
       "        (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_3): BasicLayer(\n",
       "    from 256 to 512, 2 blocks, expansion 1.\n",
       "    (block_0): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (downsample): ConvLayer(\n",
       "        (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "body = Body(BasicBlock)\n",
    "body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Body(\n",
       "  (layer_0): BasicLayer(\n",
       "    from 64 to 64, 2 blocks, expansion 1.\n",
       "    (block_0): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_1): BasicLayer(\n",
       "    from 64 to 128, 2 blocks, expansion 1.\n",
       "    (block_0): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (downsample): ConvLayer(\n",
       "        (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_2): BasicLayer(\n",
       "    from 128 to 256, 2 blocks, expansion 1.\n",
       "    (block_0): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (downsample): ConvLayer(\n",
       "        (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       "  (layer_3): BasicLayer(\n",
       "    from 256 to 512, 2 blocks, expansion 1.\n",
       "    (block_0): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (downsample): ConvLayer(\n",
       "        (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "    (block_1): BasicBlock(\n",
       "      (conv): Sequential(\n",
       "        (conv_0): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (act_fn): ReLU(inplace=True)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (conv_1): ConvLayer(\n",
       "          (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (merge): Noop()\n",
       "      (act_conn): ReLU(inplace=True)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#hide\n",
    "body = Body(BasicBlock, bn_1st=False)\n",
    "body"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([8, 512, 4, 4])\n"
     ]
    }
   ],
   "source": [
    "bs = 8\n",
    "xb = torch.randn(bs, 64, 32, 32)\n",
    "y = body(xb)\n",
    "print(y.shape)\n",
    "assert y.shape==torch.Size([bs, 512, 4, 4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BasicLayer(\n",
       "  from 64 to 64, 2 blocks, expansion 1.\n",
       "  (block_0): BasicBlock(\n",
       "    (conv): Sequential(\n",
       "      (conv_0): ConvLayer(\n",
       "        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (conv_1): ConvLayer(\n",
       "        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (merge): Noop()\n",
       "    (act_conn): ReLU(inplace=True)\n",
       "  )\n",
       "  (block_1): BasicBlock(\n",
       "    (conv): Sequential(\n",
       "      (conv_0): ConvLayer(\n",
       "        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (act_fn): ReLU(inplace=True)\n",
       "      )\n",
       "      (conv_1): ConvLayer(\n",
       "        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (merge): Noop()\n",
       "    (act_conn): ReLU(inplace=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#hide\n",
    "body = Body(BasicBlock, sa=True)\n",
    "body.layer_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([8, 512, 4, 4])\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "bs = 8\n",
    "xb = torch.randn(bs, 64, 32, 32)\n",
    "y = body(xb)\n",
    "print(y.shape)\n",
    "assert y.shape==torch.Size([bs, 512, 4, 4])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export \n",
    "class Head(nn.Sequential):\n",
    "    '''base head'''\n",
    "    def __init__(self, ni, nf, **kwargs):\n",
    "        super().__init__(OrderedDict(\n",
    "            [('pool', nn.AdaptiveAvgPool2d((1, 1))),\n",
    "             ('flat', Flatten()),\n",
    "             ('fc',   nn.Linear(ni, nf)),\n",
    "             ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "head = Head(512, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8, 10])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "bs = 8\n",
    "xb = torch.randn(bs, 512, 4, 4)\n",
    "y = head(xb)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# class Net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export\n",
    "def init_model(model, nonlinearity='leaky_relu'):\n",
    "    '''Init model'''\n",
    "    for m in model.modules():\n",
    "#             if isinstance(m, nn.Conv2d):\n",
    "            if isinstance(m, (nn.Conv2d, nn.Linear)):\n",
    "                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity=nonlinearity)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# export \n",
    "class Net(nn.Sequential):\n",
    "    '''Constructor for model'''\n",
    "    def __init__(self,  stem=Stem,\n",
    "                 body=Body, block=BasicBlock, sa=False,\n",
    "                 layer_szs=[64,128,256,], blocks=[2,2,2,2],\n",
    "                 head=Head, \n",
    "                 c_in=3,  num_classes=1000, \n",
    "                 body_in=64, body_out=512, expansion=1,\n",
    "                init_fn=init_model, **kwargs):\n",
    "        self.init_model=init_fn\n",
    "        super().__init__(OrderedDict([\n",
    "            ('stem', stem(c_in=c_in,stem_out=body_in, **kwargs)),\n",
    "            ('body', body(block, body_in, body_out, \n",
    "                        layer_szs=layer_szs, blocks=blocks, expansion=expansion, \n",
    "                        sa=sa, **kwargs)),\n",
    "            ('head', head(body_out*expansion, num_classes, **kwargs))\n",
    "            ]))\n",
    "        self.init_model(self)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Mode fan_in` not supported, please use one of ['fan_in', 'fan_out']",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-24-240c23a72b5e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-23-8faacda0c0fd>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, stem, body, block, sa, layer_szs, blocks, head, c_in, num_classes, body_in, body_out, expansion, init_fn, **kwargs)\u001b[0m\n\u001b[1;32m     17\u001b[0m             \u001b[0;34m(\u001b[0m\u001b[0;34m'head'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbody_out\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mexpansion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_classes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m             ]))\n\u001b[0;32m---> 19\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-22-4c3a9410f356>\u001b[0m in \u001b[0;36minit_model\u001b[0;34m(model, nonlinearity)\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m                 \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkaiming_normal_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'fan_in`'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnonlinearity\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnonlinearity\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/init.py\u001b[0m in \u001b[0;36mkaiming_normal_\u001b[0;34m(tensor, a, mode, nonlinearity)\u001b[0m\n\u001b[1;32m    343\u001b[0m         \u001b[0;34m>>\u001b[0m\u001b[0;34m>\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkaiming_normal_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'fan_out'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnonlinearity\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'relu'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    344\u001b[0m     \"\"\"\n\u001b[0;32m--> 345\u001b[0;31m     \u001b[0mfan\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_calculate_correct_fan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    346\u001b[0m     \u001b[0mgain\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcalculate_gain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnonlinearity\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    347\u001b[0m     \u001b[0mstd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgain\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfan\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/init.py\u001b[0m in \u001b[0;36m_calculate_correct_fan\u001b[0;34m(tensor, mode)\u001b[0m\n\u001b[1;32m    275\u001b[0m     \u001b[0mvalid_modes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'fan_in'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'fan_out'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    276\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mvalid_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 277\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Mode {} not supported, please use one of {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_modes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    278\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    279\u001b[0m     \u001b[0mfan_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfan_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_calculate_fan_in_and_fan_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Mode fan_in` not supported, please use one of ['fan_in', 'fan_out']"
     ]
    }
   ],
   "source": [
    "model = Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'model' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-25-1f8a688cae5d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
     ]
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'model' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-26-f906a0217dd0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0mbs_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m16\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mxb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbs_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
     ]
    }
   ],
   "source": [
    "bs_test = 16\n",
    "xb = torch.randn(bs_test, 3, 128, 128)\n",
    "y = model(xb)\n",
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "Mode fan_in` not supported, please use one of ['fan_in', 'fan_out']",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-27-61ce31320a2d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m#hide\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbn_1st\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbody\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayer_0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-23-8faacda0c0fd>\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, stem, body, block, sa, layer_szs, blocks, head, c_in, num_classes, body_in, body_out, expansion, init_fn, **kwargs)\u001b[0m\n\u001b[1;32m     17\u001b[0m             \u001b[0;34m(\u001b[0m\u001b[0;34m'head'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbody_out\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mexpansion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_classes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m             ]))\n\u001b[0;32m---> 19\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-22-4c3a9410f356>\u001b[0m in \u001b[0;36minit_model\u001b[0;34m(model, nonlinearity)\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mConv2d\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m                 \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkaiming_normal_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'fan_in`'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnonlinearity\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnonlinearity\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/init.py\u001b[0m in \u001b[0;36mkaiming_normal_\u001b[0;34m(tensor, a, mode, nonlinearity)\u001b[0m\n\u001b[1;32m    343\u001b[0m         \u001b[0;34m>>\u001b[0m\u001b[0;34m>\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkaiming_normal_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mw\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'fan_out'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnonlinearity\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'relu'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    344\u001b[0m     \"\"\"\n\u001b[0;32m--> 345\u001b[0;31m     \u001b[0mfan\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_calculate_correct_fan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    346\u001b[0m     \u001b[0mgain\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcalculate_gain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnonlinearity\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0ma\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    347\u001b[0m     \u001b[0mstd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgain\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfan\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/init.py\u001b[0m in \u001b[0;36m_calculate_correct_fan\u001b[0;34m(tensor, mode)\u001b[0m\n\u001b[1;32m    275\u001b[0m     \u001b[0mvalid_modes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'fan_in'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'fan_out'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    276\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mmode\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mvalid_modes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 277\u001b[0;31m         \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Mode {} not supported, please use one of {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmode\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalid_modes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    278\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    279\u001b[0m     \u001b[0mfan_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfan_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_calculate_fan_in_and_fan_out\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mValueError\u001b[0m: Mode fan_in` not supported, please use one of ['fan_in', 'fan_out']"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "model = Net(bn_1st=False)\n",
    "model.body.layer_0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'model' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-28-c8cf4483beae>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mbs_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m8\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0mxb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrandn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbs_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m128\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      6\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mbs_test\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1000\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34mf\"size expected {bs_test}, 1000\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "bs_test = 8\n",
    "xb = torch.randn(bs_test, 3, 128, 128)\n",
    "y = model(xb)\n",
    "print(y.shape)\n",
    "assert y.shape == torch.Size([bs_test, 1000]), f\"size expected {bs_test}, 1000\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Net(\n",
       "  (stem): Stem(\n",
       "    sizes: [3, 64]\n",
       "    (conv0): ConvLayer(\n",
       "      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (act_fn): ReLU(inplace=True)\n",
       "    )\n",
       "    (pool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  )\n",
       "  (body): Body(\n",
       "    (layer_0): BasicLayer(\n",
       "      from 64 to 64, 2 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_1): BasicLayer(\n",
       "      from 64 to 128, 2 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_2): BasicLayer(\n",
       "      from 128 to 256, 2 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "    (layer_3): BasicLayer(\n",
       "      from 256 to 512, 2 blocks, expansion 1.\n",
       "      (block_0): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (downsample): ConvLayer(\n",
       "          (conv): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "      (block_1): BasicBlock(\n",
       "        (conv): Sequential(\n",
       "          (conv_0): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (act_fn): ReLU(inplace=True)\n",
       "          )\n",
       "          (conv_1): ConvLayer(\n",
       "            (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (merge): Noop()\n",
       "        (act_conn): ReLU(inplace=True)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (head): Head(\n",
       "    (pool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "    (flat): Flatten()\n",
       "    (fc): Linear(in_features=512, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#hide\n",
    "model = Net(sa=True)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([8, 1000])\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "bs_test = 8\n",
    "xb = torch.randn(bs_test, 3, 128, 128)\n",
    "y = model(xb)\n",
    "print(y.shape)\n",
    "assert y.shape == torch.Size([bs_test, 1000]), f\"size expected {bs_test}, 1000\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# model constructor\n",
    "by ayasyrev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_constructor.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_resnet.ipynb.\n",
      "Converted 03_xresnet.ipynb.\n",
      "Converted index.ipynb.\n"
     ]
    }
   ],
   "source": [
    "# hide\n",
    "from nbdev.export import *\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
